{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libraries and set up matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '2'\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import sys\n",
    "sys.path.append('waveglow/')\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import IPython.display as ipd\n",
    "import pickle as pkl\n",
    "from text import *\n",
    "import numpy as np\n",
    "import torch\n",
    "import hparams\n",
    "from modules.model import Model\n",
    "from denoiser import Denoiser\n",
    "import soundfile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from g2p_en import G2p\n",
    "from text.symbols import symbols\n",
    "from text.cleaners import custom_english_cleaners\n",
    "\n",
    "# Lookup tables: numeric ID -> symbol, and the inverse symbol -> numeric ID.\n",
    "id_to_symbol = dict(enumerate(symbols))\n",
    "symbol_to_id = {symbol: index for index, symbol in id_to_symbol.items()}\n",
    "\n",
    "# Single shared G2p instance, called by text2seq for data_type='phone'.\n",
    "g2p = G2p()\n",
    "def text2seq(text, data_type='char'):\n",
    "    \"\"\"Convert raw text into a list of symbol IDs wrapped in SOS/EOS tokens.\n",
    "\n",
    "    Args:\n",
    "        text: Input sentence; trailing whitespace is stripped before cleaning.\n",
    "        data_type: 'char' looks up each cleaned character directly; 'phone'\n",
    "            first runs the cleaned, lower-cased text through g2p and looks up\n",
    "            the resulting tokens instead.\n",
    "\n",
    "    Returns:\n",
    "        A list of integer IDs that starts with the ID of '^' and ends with\n",
    "        the ID of '~'.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If a token has no entry in symbol_to_id.\n",
    "    \"\"\"\n",
    "    tokens = custom_english_cleaners(text.rstrip())\n",
    "    if data_type == 'phone':\n",
    "        # Use the '@'-prefixed entry when the symbol table has one; any other\n",
    "        # g2p token is looked up as-is. The token list always replaces the\n",
    "        # cleaned string, even when every token has an '@' entry.\n",
    "        tokens = [\n",
    "            '@' + s if '@' + s in symbol_to_id else s\n",
    "            for s in g2p(tokens.lower())\n",
    "        ]\n",
    "\n",
    "    # Wrap the sequence in the SOS ('^') and EOS ('~') tokens.\n",
    "    sequence = [symbol_to_id[c] for c in tokens]\n",
    "    return [symbol_to_id['^']] + sequence + [symbol_to_id['~']]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Waveglow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "waveglow_path = 'training_log/waveglow_256channels.pt'\n",
    "# NOTE(review): torch.load unpickles the file, so only load checkpoints\n",
    "# that come from a trusted source.\n",
    "waveglow = torch.load(waveglow_path)['model']\n",
    "\n",
    "# Set padding_mode on every module whose type name contains 'Conv'\n",
    "# (presumably the pickled modules lack this attribute -- TODO confirm).\n",
    "for m in waveglow.modules():\n",
    "    if 'Conv' in str(type(m)):\n",
    "        m.padding_mode = 'zeros'\n",
    "\n",
    "waveglow.cuda().eval()\n",
    "# Cast each module in waveglow.convinv to float32.\n",
    "for k in waveglow.convinv:\n",
    "    k.float()\n",
    "\n",
    "denoiser = Denoiser(waveglow)\n",
    "\n",
    "# Each filelist line is split on '|' into its fields. Decode explicitly as\n",
    "# UTF-8 (assumes the filelist is UTF-8 encoded) so the parsed text does not\n",
    "# depend on the platform's default encoding.\n",
    "with open('filelists/ljs_audio_text_val_filelist.txt', 'r', encoding='utf-8') as f:\n",
    "    lines = [line.split('|') for line in f.read().splitlines()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Make sure the output directory exists before any audio is written.\n",
    "os.makedirs('wavs', exist_ok=True)\n",
    "\n",
    "# Synthesize the same validation sentences with every checkpoint of both the\n",
    "# character-based and the phoneme-based model.\n",
    "for data_type in ['char', 'phone']:\n",
    "    for step in ['20000', '50000', '100000', '200000']:\n",
    "        checkpoint_path = f\"training_log/transformer-tts-{data_type}/checkpoint_{step}\"\n",
    "        # Drop a leading 'module.' from parameter names (presumably added by\n",
    "        # torch.nn.DataParallel during training -- TODO confirm); keys without\n",
    "        # that prefix are kept unchanged instead of being truncated.\n",
    "        prefix = 'module.'\n",
    "        state_dict = {\n",
    "            (k[len(prefix):] if k.startswith(prefix) else k): v\n",
    "            for k, v in torch.load(checkpoint_path)['state_dict'].items()\n",
    "        }\n",
    "\n",
    "        model = Model(hparams).cuda()\n",
    "        model.load_state_dict(state_dict)\n",
    "        model.eval()\n",
    "\n",
    "        for i in [1, 6, 22]:\n",
    "            file_name, _, text = lines[i]\n",
    "            # Add a leading batch axis of size 1 before moving the IDs to the GPU.\n",
    "            sequence = np.array(text2seq(text, data_type))[None, :]\n",
    "            sequence = torch.from_numpy(sequence).cuda().long()\n",
    "\n",
    "            with torch.no_grad():\n",
    "                melspec, enc_alignments, dec_alignments, enc_dec_alignments, stop = model.inference(sequence, max_len=2048)\n",
    "                # Keep only the first len(stop) entries along the last axis.\n",
    "                melspec = melspec[:,:,:len(stop)]\n",
    "                audio = waveglow.infer(melspec, sigma=0.666)\n",
    "\n",
    "            soundfile.write(f'wavs/{file_name}_{data_type}{step}.wav', audio.cpu().numpy()[0].astype(float), 22050)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_pytorch_p36)",
   "language": "python",
   "name": "conda_pytorch_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
